{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# T2. databundle 和 tokenizer 的基本使用\n",
    "\n",
    "&emsp; 1 &ensp; fastNLP 中 dataset 的延伸\n",
    "\n",
    "&emsp; &emsp; 1.1 &ensp; databundle 的概念与使用\n",
    "\n",
    "&emsp; 2 &ensp; fastNLP 中的 tokenizer\n",
    " \n",
    "&emsp; &emsp; 2.1 &ensp; PreTrainedTokenizer 的概念\n",
    "\n",
    "&emsp; &emsp; 2.2 &ensp; BertTokenizer 的基本使用\n",
    "<!--  \n",
    "&emsp; &emsp; 2.3 &ensp; 补充：GloVe 词嵌入的使用 -->"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. fastNLP 中 dataset 的延伸\n",
    "\n",
    "### 1.1 databundle 的概念与使用\n",
    "\n",
    "在`fastNLP 1.0`中，在常用的数据加载模块`DataLoader`和数据集`DataSet`模块之间，还存在\n",
    "\n",
    "&emsp; 一个中间模块，即 **数据包 DataBundle 模块**，可以从`fastNLP.io`路径中导入该模块\n",
    "\n",
    "在`fastNLP 1.0`中，**一个 databundle 数据包包含若干 dataset 数据集和 vocabulary 词汇表**\n",
    "\n",
    "&emsp; 分别存储在`datasets`和`vocabs`两个变量中，所以了解`databundle`数据包之前\n",
    "\n",
    "需要首先**复习 dataset 数据集和 vocabulary 词汇表**，**下面的一串代码**，**你知道其大概含义吗？**\n",
    "\n",
    "<!-- 必要提示：`NG20`，全称[`News Group 20`](http://qwone.com/~jason/20Newsgroups/)，是一个新闻文本分类数据集，包含20个类别\n",
    "\n",
    "&emsp; 数据集包含训练集`'ng20_train.csv'`和测试集`'ng20_test.csv'`两部分，每条数据\n",
    "\n",
    "&emsp; 包括`'label'`标签和`'text'`文本两个条目，通过`sample(frac=1)[:6]`随机采样并读取前6条 -->"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Processing:   0%|          | 0/6 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+------------------------------------------+----------+\n",
      "| text                                     | label    |\n",
      "+------------------------------------------+----------+\n",
      "| ['a', 'series', 'of', 'escapades', 'd... | negative |\n",
      "| ['this', 'quiet', ',', 'introspective... | positive |\n",
      "| ['even', 'fans', 'of', 'ismail', 'mer... | negative |\n",
      "| ['the', 'importance', 'of', 'being', ... | neutral  |\n",
      "+------------------------------------------+----------+\n",
      "+------------------------------------------+----------+\n",
      "| text                                     | label    |\n",
      "+------------------------------------------+----------+\n",
      "| ['a', 'comedy-drama', 'of', 'nearly',... | positive |\n",
      "| ['a', 'positively', 'thrilling', 'com... | neutral  |\n",
      "+------------------------------------------+----------+\n",
      "{'<pad>': 0, '<unk>': 1, 'negative': 2, 'positive': 3, 'neutral': 4}\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "\n",
    "from fastNLP import DataSet\n",
    "from fastNLP import Vocabulary\n",
    "from fastNLP.io import DataBundle\n",
    "\n",
    "datasets = DataSet.from_pandas(pd.read_csv('./data/test4dataset.tsv', sep='\\t'))\n",
    "datasets.rename_field('Sentence', 'text')\n",
    "datasets.rename_field('Sentiment', 'label')\n",
    "datasets.apply_more(lambda ins:{'label': ins['label'].lower(), \n",
    "                                'text': ins['text'].lower().split()},\n",
    "                    progress_bar='tqdm')\n",
    "datasets.delete_field('SentenceId')\n",
    "train_ds, test_ds = datasets.split(ratio=0.7)\n",
    "datasets = {'train': train_ds, 'test': test_ds}\n",
    "print(datasets['train'])\n",
    "print(datasets['test'])\n",
    "\n",
    "vocabs = {}\n",
    "vocabs['label'] = Vocabulary().from_dataset(datasets['train'].concat(datasets['test'], inplace=False), field_name='label')\n",
    "vocabs['text'] = Vocabulary().from_dataset(datasets['train'].concat(datasets['test'], inplace=False), field_name='text')\n",
    "print(vocabs['label'].word2idx)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<!-- 上述代码的含义是：随机读取`NG20`数据集中的各十条训练数据和测试数据，将标签都设为小写，对文本进行分词\n",
    " -->\n",
    "上述代码的含义是：从`test4dataset`的 6 条数据中，划分 4 条训练集（`int(6*0.7) = 4`），2 条测试集\n",
    "\n",
    "&emsp; &emsp; 修改相关字段名称，删除序号字段，同时将标签都设为小写，对文本进行分词\n",
    "\n",
    "&emsp; 接着通过`concat`方法拼接测试集训练集，注意设置`inplace=False`，生成临时的新数据集\n",
    "\n",
    "&emsp; 使用`from_dataset`方法从拼接的数据集中抽取词汇表，为将数据集中的单词替换为序号做准备\n",
    "\n",
    "由此就可以得到**数据集字典 datasets**（**对应训练集、测试集**）和**词汇表字典 vocabs**（**对应数据集各字段**）\n",
    "\n",
    "&emsp; 然后就可以初始化`databundle`了，通过`print`可以观察其大致结构，效果如下"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "In total 2 datasets:\n",
      "\ttrain has 4 instances.\n",
      "\ttest has 2 instances.\n",
      "In total 2 vocabs:\n",
      "\tlabel has 5 entries.\n",
      "\ttext has 96 entries.\n",
      "\n",
      "['train', 'test']\n",
      "['label', 'text']\n"
     ]
    }
   ],
   "source": [
    "data_bundle = DataBundle(datasets=datasets, vocabs=vocabs)\n",
    "print(data_bundle)\n",
    "print(data_bundle.get_dataset_names())\n",
    "print(data_bundle.get_vocab_names())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "此外，也可以通过`data_bundle`的`num_dataset`和`num_vocab`返回数据表和词汇表个数\n",
    "\n",
    "&emsp; 通过`data_bundle`的`iter_datasets`和`iter_vocabs`遍历数据表和词汇表"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "In total 2 datasets:\n",
      "\ttrain has 4 instances.\n",
      "\ttest has 2 instances.\n",
      "In total 2 datasets:\n",
      "\tlabel has 5 entries.\n",
      "\ttext has 96 entries.\n"
     ]
    }
   ],
   "source": [
    "print(\"In total %d datasets:\" % data_bundle.num_dataset)\n",
    "for name, dataset in data_bundle.iter_datasets():\n",
    "    print(\"\\t%s has %d instances.\" % (name, len(dataset)))\n",
    "print(\"In total %d datasets:\" % data_bundle.num_dataset)\n",
    "for name, vocab in data_bundle.iter_vocabs():\n",
    "    print(\"\\t%s has %d entries.\" % (name, len(vocab)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在数据包`databundle`中，也有和数据集`dataset`类似的四个`apply`函数，即\n",
    "\n",
    "&emsp; `apply`函数、`apply_field`函数、`apply_field_more`函数和`apply_more`函数\n",
    "\n",
    "&emsp; 负责对数据集进行预处理，如下所示是`apply_more`函数的示例，其他函数类似\n",
    "\n",
    "此外，通过`get_dataset`函数，可以通过数据表名`name`称找到对应数据表\n",
    "\n",
    "&emsp; 通过`get_vocab`函数，可以通过词汇表名`field_name`称找到对应词汇表"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Processing:   0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Processing:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+------------------------------+----------+-----+\n",
      "| text                         | label    | len |\n",
      "+------------------------------+----------+-----+\n",
      "| ['a', 'series', 'of', 'es... | negative | 37  |\n",
      "| ['this', 'quiet', ',', 'i... | positive | 11  |\n",
      "| ['even', 'fans', 'of', 'i... | negative | 21  |\n",
      "| ['the', 'importance', 'of... | neutral  | 20  |\n",
      "+------------------------------+----------+-----+\n"
     ]
    }
   ],
   "source": [
    "data_bundle.apply_more(lambda ins:{'len': len(ins['text'])}, progress_bar='tqdm')\n",
    "print(data_bundle.get_dataset('train'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. fastNLP 中的 tokenizer\n",
    "\n",
    "### 2.1 PreTrainTokenizer 的提出\n",
    "<!-- \n",
    "*词嵌入是什么，为什么不用了*\n",
    "\n",
    "*什么是字节对编码，BPE的提出*\n",
    "\n",
    "*以BERT模型为例，WordPiece的提出*\n",
    " -->\n",
    "在`fastNLP 1.0`中，**使用 PreTrainedTokenizer 模块来为数据集中的词语进行词向量的标注**\n",
    "\n",
    "&emsp; 需要注意的是，`PreTrainedTokenizer`模块的下载和导入**需要确保环境安装了 transformers 模块**\n",
    "\n",
    "&emsp; 这是因为 `fastNLP 1.0`中`PreTrainedTokenizer`模块的实现基于`Huggingface Transformers`库\n",
    "\n",
    "**Huggingface Transformers 是一个开源的**，**基于 transformer 模型结构提供的预训练语言库**\n",
    "\n",
    "&emsp; 包含了多种经典的基于`transformer`的预训练模型，如`BERT`、`BART`、`RoBERTa`、`GPT2`、`CPT`\n",
    "\n",
    "&emsp; 更多相关内容可以参考`Huggingface Transformers`的[相关论文](https://arxiv.org/pdf/1910.03771.pdf)、[官方文档](https://huggingface.co/transformers/)以及[的代码仓库](https://github.com/huggingface/transformers)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.2 BertTokenizer 的基本使用\n",
    "\n",
    "在`fastNLP 1.0`中，以`PreTrainedTokenizer`为基类，泛化出多个子类，实现基于`BERT`等模型的标注\n",
    "\n",
    "&emsp; 本节以`BertTokenizer`模块为例，展示`PreTrainedTokenizer`模块的使用方法与应用实例\n",
    "\n",
    "**BertTokenizer 的初始化包括 导入模块和导入数据 两步**，先通过从`fastNLP.transformers.torch`中\n",
    "\n",
    "&emsp; 导入`BertTokenizer`模块，再**通过 from_pretrained 方法指定 tokenizer 参数类型下载**\n",
    "\n",
    "&emsp; 其中，**'bert-base-uncased' 指定 tokenizer 使用的预训练 BERT 类型**：单词不区分大小写\n",
    "\n",
    "&emsp; &emsp; **模块层数 L=12**，**隐藏层维度 H=768**，**自注意力头数 A=12**，**总参数量 110M**\n",
    "\n",
    "&emsp; 另外，模型参数自动下载至 home 目录下的`~\\.cache\\huggingface\\transformers`文件夹中"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "from fastNLP.transformers.torch import BertTokenizer\n",
    "\n",
    "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "通过变量`vocab_size`和`vocab_files_names`可以查看`BertTokenizer`的词汇表的大小和对应文件\n",
    "\n",
    "&emsp; 通过变量`vocab`可以访问`BertTokenizer`预训练的词汇表（由于内容过大就不演示了"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "30522 {'vocab_file': 'vocab.txt'}\n"
     ]
    }
   ],
   "source": [
    "print(tokenizer.vocab_size, tokenizer.vocab_files_names)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "通过变量`all_special_tokens`或通过变量`special_tokens_map`可以**查看 BertTokenizer 内置的特殊词素**\n",
    "\n",
    "&emsp; 包括**未知符 '[UNK]'**, **断句符 '[SEP]'**, **补零符 '[PAD]'**, **分类符 '[CLS]'**, **掩码 '[MASK]'**\n",
    "\n",
    "通过变量`all_special_ids`可以**查看 BertTokenizer 内置的特殊词素对应的词汇表编号**，相同功能\n",
    "\n",
    "&emsp; 也可以直接通过查看`pad_token`，值为`'[UNK]'`，和`pad_token_id`，值为`0`，等变量来实现"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "pad_token [PAD] 0\n",
      "unk_token [UNK] 100\n",
      "cls_token [CLS] 101\n",
      "sep_token [SEP] 102\n",
      "msk_token [MASK] 103\n",
      "all_tokens ['[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]'] [100, 102, 0, 101, 103]\n",
      "{'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}\n"
     ]
    }
   ],
   "source": [
    "print('pad_token', tokenizer.pad_token, tokenizer.pad_token_id) \n",
    "print('unk_token', tokenizer.unk_token, tokenizer.unk_token_id) \n",
    "print('cls_token', tokenizer.cls_token, tokenizer.cls_token_id) \n",
    "print('sep_token', tokenizer.sep_token, tokenizer.sep_token_id)\n",
    "print('msk_token', tokenizer.mask_token, tokenizer.mask_token_id)\n",
    "print('all_tokens', tokenizer.all_special_tokens, tokenizer.all_special_ids)\n",
    "print(tokenizer.special_tokens_map)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "此外，还可以添加其他特殊字符，例如起始符`[BOS]`、终止符`[EOS]`，添加后词汇表编号也会相应改变\n",
    "\n",
    "&emsp; *但是如何添加这两个之外的字符，并且如何将这两个的编号设置为 [UNK] 之外的编号？？？*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "bos_token [BOS] 100\n",
      "eos_token [EOS] 100\n",
      "all_tokens ['[BOS]', '[EOS]', '[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]'] [100, 100, 100, 102, 0, 101, 103]\n",
      "{'bos_token': '[BOS]', 'eos_token': '[EOS]', 'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}\n"
     ]
    }
   ],
   "source": [
    "tokenizer.bos_token = '[BOS]'\n",
    "tokenizer.eos_token = '[EOS]'\n",
    "# tokenizer.bos_token_id = 104\n",
    "# tokenizer.eos_token_id = 105\n",
    "print('bos_token', tokenizer.bos_token, tokenizer.bos_token_id)\n",
    "print('eos_token', tokenizer.eos_token, tokenizer.eos_token_id)\n",
    "print('all_tokens', tokenizer.all_special_tokens, tokenizer.all_special_ids)\n",
    "print(tokenizer.special_tokens_map)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在`BertTokenizer`中，**使用 tokenize 函数和 convert_tokens_to_string 函数可以实现文本和词素列表的互转**\n",
    "\n",
    "&emsp; 此外，**使用 convert_tokens_to_ids 函数和 convert_ids_to_tokens 函数则可以实现词素和词素编号的互转**\n",
    "\n",
    "&emsp; 上述四个函数的使用效果如下所示，此处可以明显看出，`tokenizer`分词和传统分词的不同效果，例如`'##cap'`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1037, 2186, 1997, 9686, 17695, 18673, 14313, 1996, 15262, 3351, 2008, 2054, 2003, 2204, 2005, 1996, 13020, 2003, 2036, 2204, 2005, 1996, 25957, 4063, 1010, 2070, 1997, 2029, 5681, 2572, 25581, 2021, 3904, 1997, 2029, 8310, 2000, 2172, 1997, 1037, 2466, 1012]\n",
      "['a', 'series', 'of', 'es', '##cap', '##ades', 'demonstrating', 'the', 'ada', '##ge', 'that', 'what', 'is', 'good', 'for', 'the', 'goose', 'is', 'also', 'good', 'for', 'the', 'gan', '##der', ',', 'some', 'of', 'which', 'occasionally', 'am', '##uses', 'but', 'none', 'of', 'which', 'amounts', 'to', 'much', 'of', 'a', 'story', '.']\n",
      "a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .\n"
     ]
    }
   ],
   "source": [
    "text = \"a series of escapades demonstrating the adage that what is \" \\\n",
    "       \"good for the goose is also good for the gander , some of which \" \\\n",
    "       \"occasionally amuses but none of which amounts to much of a story .\" \n",
    "tks = ['a', 'series', 'of', 'es', '##cap', '##ades', 'demonstrating', 'the', \n",
    "       'ada', '##ge', 'that', 'what', 'is', 'good', 'for', 'the', 'goose', \n",
    "       'is', 'also', 'good', 'for', 'the', 'gan', '##der', ',', 'some', 'of', \n",
    "       'which', 'occasionally', 'am', '##uses', 'but', 'none', 'of', 'which', \n",
    "       'amounts', 'to', 'much', 'of', 'a', 'story', '.']\n",
    "ids = [ 1037,  2186,  1997,  9686, 17695, 18673, 14313,  1996, 15262,  3351, \n",
    "        2008,  2054,  2003,  2204,  2005,  1996, 13020,  2003,  2036,  2204,\n",
    "        2005,  1996, 25957,  4063,  1010,  2070,  1997,  2029,  5681,  2572,\n",
    "       25581,  2021,  3904,  1997,  2029,  8310,  2000,  2172,  1997,  1037,\n",
    "        2466,  1012]\n",
    "\n",
    "tokens = tokenizer.tokenize(text)\n",
    "print(tokenizer.convert_tokens_to_ids(tokens))\n",
    "\n",
    "ids = tokenizer.convert_tokens_to_ids(tokens)\n",
    "print(tokenizer.convert_ids_to_tokens(ids))\n",
    "\n",
    "print(tokenizer.convert_tokens_to_string(tokens))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在`BertTokenizer`中，还有另外两个函数可以实现分词标注，分别是 **encode 和 decode 函数**，**可以直接实现**\n",
    "\n",
    "&emsp; **文本字符串和词素编号列表的互转**，但是编码过程中会按照`BERT`的规则，**在句子首末加入 [CLS] 和 [SEP]**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[101, 1037, 2186, 1997, 9686, 17695, 18673, 14313, 1996, 15262, 3351, 2008, 2054, 2003, 2204, 2005, 1996, 13020, 2003, 2036, 2204, 2005, 1996, 25957, 4063, 1010, 2070, 1997, 2029, 5681, 2572, 25581, 2021, 3904, 1997, 2029, 8310, 2000, 2172, 1997, 1037, 2466, 1012, 102]\n",
      "[CLS] a series of escapades demonstrating the adage that what is good for the goose is also good for the gander, some of which occasionally amuses but none of which amounts to much of a story. [SEP]\n"
     ]
    }
   ],
   "source": [
    "enc = tokenizer.encode(text)\n",
    "print(tokenizer.encode(text))\n",
    "dec = tokenizer.decode(enc)\n",
    "print(tokenizer.decode(enc))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在`encode`函数之上，还有`encode_plus`函数，这也是在数据预处理中，`BertTokenizer`模块最常用到的函数\n",
    "\n",
    "&emsp; **encode 函数的参数**，**encode_plus 函数都有**；**encode 函数词素编号列表**，**encode_plus 函数返回字典**\n",
    "\n",
    "在`encode_plus`函数的返回值中，字段`input_ids`表示词素编号，其余两个字段后文有详细解释\n",
    "\n",
    "&emsp; **字段 token_type_ids 详见 text_pairs 的示例**，**字段 attention_mask 详见 batch_text 的示例**\n",
    "\n",
    "在`encode_plus`函数的参数中，参数`add_special_tokens`表示是否按照`BERT`的规则，加入相关特殊字符\n",
    "\n",
    "&emsp; 参数`max_length`表示句子截取最大长度（算特殊字符），在参数`truncation=True`时会自动截取\n",
    "\n",
    "&emsp; 参数`return_attention_mask`约定返回的字典中是否包括`attention_mask`字段，以上案例如下"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'input_ids': [101, 1037, 2186, 1997, 9686, 17695, 18673, 14313, 1996, 15262, 3351, 2008, 2054, 2003, 2204, 2005, 1996, 13020, 2003, 2036, 2204, 2005, 1996, 25957, 4063, 1010, 2070, 1997, 2029, 5681, 2572, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}\n"
     ]
    }
   ],
   "source": [
    "text = \"a series of escapades demonstrating the adage that what is good for the goose is also good for \"\\\n",
    "       \"the gander , some of which occasionally amuses but none of which amounts to much of a story .\" \n",
    "\n",
    "encoded = tokenizer.encode_plus(text=text, add_special_tokens=True, max_length=32, \n",
    "                                truncation=True, return_attention_mask=True)\n",
    "print(encoded)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在`encode_plus`函数之上，还有`batch_encode_plus`函数（类似地，在`decode`之上，还有`batch_decode`\n",
    "\n",
    "&emsp; 两者参数类似，**batch_encode_plus 函数针对批量文本 batch_text**，**或者批量句对 text_pairs**\n",
    "\n",
    "在针对批量文本`batch_text`的示例中，注意`batch_encode_plus`函数返回字典中的`attention_mask`字段\n",
    "\n",
    "&emsp; 可以发现，**attention_mask 字段通过 01 标注出词素序列中该位置是否为补零**，可以用做自注意力的掩模"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'input_ids': [[101, 1037, 2186, 1997, 9686, 17695, 18673, 14313, 1996, 15262, 3351, 2008, 102, 0, 0], [101, 2054, 2003, 2204, 2005, 1996, 13020, 2003, 2036, 2204, 2005, 1996, 25957, 4063, 102], [101, 2070, 1997, 2029, 5681, 2572, 25581, 102, 0, 0, 0, 0, 0, 0, 0], [101, 2021, 3904, 1997, 2029, 8310, 2000, 2172, 1997, 1037, 2466, 102, 0, 0, 0]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]]}\n"
     ]
    }
   ],
   "source": [
    "batch_text = [\"a series of escapades demonstrating the adage that\",\n",
    "              \"what is good for the goose is also good for the gander\",\n",
    "              \"some of which occasionally amuses\",\n",
    "              \"but none of which amounts to much of a story\" ]\n",
    "\n",
    "encoded = tokenizer.batch_encode_plus(batch_text_or_text_pairs=batch_text, padding=True,\n",
    "                                      add_special_tokens=True, max_length=16, truncation=True, \n",
    "                                      return_attention_mask=True)\n",
    "print(encoded)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "而在针对批量句对`text_pairs`的示例中，注意`batch_encode_plus`函数返回字典中的`attention_mask`字段\n",
    "\n",
    "&emsp; 可以发现，**token_type_ids 字段通过 01 标注出词素序列中该位置为句对中的第几句**，句对用 [SEP] 分割"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'input_ids': [[101, 1037, 2186, 1997, 9686, 17695, 18673, 14313, 1996, 15262, 3351, 2008, 102, 2054, 2003, 2204, 2005, 1996, 13020, 2003, 2036, 2204, 2005, 1996, 25957, 4063, 102], [101, 2070, 1997, 2029, 5681, 2572, 25581, 102, 2021, 3904, 1997, 2029, 8310, 2000, 2172, 1997, 1037, 2466, 102, 0, 0, 0, 0, 0, 0, 0, 0]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0]]}\n"
     ]
    }
   ],
   "source": [
    "text_pairs = [(\"a series of escapades demonstrating the adage that\",\n",
    "               \"what is good for the goose is also good for the gander\"),\n",
    "              (\"some of which occasionally amuses\",\n",
    "               \"but none of which amounts to much of a story\")]\n",
    "\n",
    "encoded = tokenizer.batch_encode_plus(batch_text_or_text_pairs=text_pairs, padding=True,\n",
    "                                      add_special_tokens=True, max_length=32, truncation=True, \n",
    "                                      return_attention_mask=True)\n",
    "print(encoded)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "回到`encode_plus`上，在接下来的示例中，**使用内置的 functools.partial 模块构造 encode 函数**\n",
    "\n",
    "&emsp; 接着**使用该函数对 databundle 进行数据预处理**，由于`tokenizer.encode_plus`返回的是一个字典\n",
    "\n",
    "&emsp; 读入的是一个字段，所以此处使用`apply_field_more`方法，得到结果自动并入`databundle`中如下"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "functools.partial(<bound method PreTrainedTokenizerBase.encode_plus of PreTrainedTokenizer(name_or_path='bert-base-uncased', vocab_size=30522, model_max_len=512, is_fast=False, padding_side='right', special_tokens={'bos_token': '[BOS]', 'eos_token': '[EOS]', 'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})>, max_length=32, truncation=True, return_attention_mask=True)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Processing:   0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Processing:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+------------------+----------+-----+------------------+--------------------+--------------------+\n",
      "| text             | label    | len | input_ids        | token_type_ids     | attention_mask     |\n",
      "+------------------+----------+-----+------------------+--------------------+--------------------+\n",
      "| ['a', 'series... | negative | 37  | [101, 1037, 2... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... |\n",
      "| ['this', 'qui... | positive | 11  | [101, 2023, 4... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... |\n",
      "| ['even', 'fan... | negative | 21  | [101, 2130, 4... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... |\n",
      "| ['the', 'impo... | neutral  | 20  | [101, 1996, 5... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... |\n",
      "+------------------+----------+-----+------------------+--------------------+--------------------+\n"
     ]
    }
   ],
   "source": [
    "from functools import partial\n",
    "\n",
    "encode = partial(tokenizer.encode_plus, max_length=32, truncation=True,\n",
    "                 return_attention_mask=True)\n",
    "print(encode)\n",
    "\n",
    "data_bundle.apply_field_more(encode, field_name='text', progress_bar='tqdm')\n",
    "print(data_bundle.datasets['train'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "经过`tokenizer`的处理，原始数据集中的文本被替换为词素编号列表，此时，调用`databundle`模块的\n",
    "\n",
    "&emsp; **set_pad 函数**，**将 databundle 的补零符编号 pad_val 和 tokenizer 补零符编号 pad_token_id 统一**\n",
    "\n",
    "&emsp; 该函数同时将`databundle`的`'input_ids'`字段添加到对应数据集的`collator`中（见`tutorial 3.`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{}\n",
      "{}\n",
      "{'input_ids': {'pad_val': 0, 'dtype': None, 'backend': 'auto', 'pad_fn': None}}\n",
      "{'input_ids': {'pad_val': 0, 'dtype': None, 'backend': 'auto', 'pad_fn': None}}\n"
     ]
    }
   ],
   "source": [
    "print(data_bundle.get_dataset('train').collator.input_fields)\n",
    "print(data_bundle.get_dataset('test').collator.input_fields)\n",
    "data_bundle.set_pad('input_ids', pad_val=tokenizer.pad_token_id)\n",
    "print(data_bundle.get_dataset('train').collator.input_fields)\n",
    "print(data_bundle.get_dataset('test').collator.input_fields)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "最后，使用`from_dataset`、`index_dataset`和`iter_datasets`方法，为处理数据集的`'label'`字段编码\n",
    "\n",
    "&emsp; 接着**通过 set_ignore 函数**，**指定 databundle 的部分字段**，如`'text'`等，**在划分 batch 时不再出现**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+----------------+----------+-----+----------------+--------------------+--------------------+--------+\n",
      "| text           | label    | len | input_ids      | token_type_ids     | attention_mask     | target |\n",
      "+----------------+----------+-----+----------------+--------------------+--------------------+--------+\n",
      "| ['a', 'seri... | negative | 37  | [101, 1037,... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... | 0      |\n",
      "| ['this', 'q... | positive | 11  | [101, 2023,... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... | 1      |\n",
      "| ['even', 'f... | negative | 21  | [101, 2130,... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... | 0      |\n",
      "| ['the', 'im... | neutral  | 20  | [101, 1996,... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... | 2      |\n",
      "+----------------+----------+-----+----------------+--------------------+--------------------+--------+\n"
     ]
    }
   ],
   "source": [
    "target_vocab = Vocabulary(padding=None, unknown=None)\n",
    "\n",
    "target_vocab.from_dataset(*[ds for _, ds in data_bundle.iter_datasets()], field_name='label')\n",
    "target_vocab.index_dataset(*[ds for _, ds in data_bundle.iter_datasets()], field_name='label',\n",
    "                           new_field_name='target')\n",
    "\n",
    "data_bundle.set_ignore('text', 'len', 'label') \n",
    "print(data_bundle.datasets['train'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "以上就是使用`dataset`、`vocabulary`、`databundle`和`tokenizer`实现输入文本数据的读取\n",
    "\n",
    "&emsp; 分词标注、序列化的全部预处理过程，通过下方的代码梳理，相信你会有更详细的了解\n",
    "\n",
    "```python\n",
    "# 首先，导入预训练的 BertTokenizer，这里使用 'bert-base-uncased' 版本\n",
    "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n",
    "\n",
    "# 接着，导入数据，先生成为 dataset 形式，再变成 dataset-dict，并转为 databundle 形式\n",
    "datasets = DataSet.from_pandas(pd.read_csv('./data/test4dataset.tsv', sep='\\t'))\n",
    "train_ds, test_ds = datasets.split(ratio=0.7)\n",
    "data_bundle = DataBundle(datasets={'train': train_ds, 'test': test_ds})\n",
    "\n",
    "# 然后，通过 tokenizer.encode_plus 函数，进行文本分词标注、修改并补充数据包内容\n",
    "encode = partial(tokenizer.encode_plus, max_length=100, truncation=True,\n",
    "                 return_attention_mask=True)\n",
    "data_bundle.apply_field_more(encode, field_name='Sentence', progress_bar='tqdm')\n",
    "\n",
    "# 在修改好 'text' 字段的文本信息后，接着处理 'label' 字段的预测信息\n",
    "target_vocab = Vocabulary(padding=None, unknown=None)\n",
    "target_vocab.from_dataset(*[ds for _, ds in data_bundle.iter_datasets()], field_name='Sentiment')\n",
    "target_vocab.index_dataset(*[ds for _, ds in data_bundle.iter_datasets()], field_name='Sentiment',\n",
    "                           new_field_name='target')\n",
    "\n",
    "# 最后，通过 data_bundle 的其他一些函数，完成善后内容\n",
    "data_bundle.set_pad('input_ids', pad_val=tokenizer.pad_token_id)\n",
    "data_bundle.set_ignore('SentenceId', 'Sentiment', 'Sentence')  \n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "<!-- ### 2.3 补充：GloVe 词嵌入的使用\n",
    "\n",
    "如何使用传统的GloVe词嵌入\n",
    "\n",
    "from utils import get_from_cache\n",
    "\n",
    "filepath = get_from_cache(\"http://download.fastnlp.top/embedding/glove.6B.50d.zip\") -->\n",
    "\n",
    "在接下来的`tutorial 3.`中，将会介绍`fastNLP v1.0`中的`dataloader`模块，会涉及本章中\n",
    "\n",
    "&emsp; 提到的`collator`模块，`fastNLP`的多框架适应以及完整的数据加载过程，敬请期待"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "metadata": {
     "collapsed": false
    },
    "source": []
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
